{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification - Adult Census using Vowpal Wabbit in SynapseML\n",
    "\n",
    "In this example, we predict incomes from the *Adult Census* dataset using Vowpal Wabbit (VW) classifier in SynapseML.\n",
    "First, we read the data and split it into train and test sets as in this [example](https://github.com/Microsoft/SynapseML/blob/master/notebooks/Classification%20-%20Adult%20Census.ipynb\n",
    ")."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = spark.read.parquet(\n",
    "    \"wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet\"\n",
    ")\n",
    "data = data.select([\"education\", \"marital-status\", \"hours-per-week\", \"income\"])\n",
    "train, test = data.randomSplit([0.75, 0.25], seed=123)\n",
    "train.limit(10).toPandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define a pipeline that includes feature engineering and training of a VW classifier. We use a featurizer provided by VW that hashes the feature names. \n",
    "Note that VW expects classification labels being -1 or 1. Thus, the income category is mapped to this space before feeding training data into the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.functions import when, col\n",
    "from pyspark.ml import Pipeline\n",
    "from synapse.ml.vw import VowpalWabbitFeaturizer, VowpalWabbitClassifier\n",
    "\n",
    "# Define classification label\n",
    "train = (\n",
    "    train.withColumn(\"label\", when(col(\"income\").contains(\"<\"), 0.0).otherwise(1.0))\n",
    "    .repartition(1)\n",
    "    .cache()\n",
    ")\n",
    "print(train.count())\n",
    "\n",
    "# Specify featurizer\n",
    "vw_featurizer = VowpalWabbitFeaturizer(\n",
    "    inputCols=[\"education\", \"marital-status\", \"hours-per-week\"], outputCol=\"features\"\n",
    ")\n",
    "\n",
    "# Define VW classification model\n",
    "args = \"--loss_function=logistic --quiet --holdout_off\"\n",
    "vw_model = VowpalWabbitClassifier(\n",
    "    featuresCol=\"features\", labelCol=\"label\", passThroughArgs=args, numPasses=10\n",
    ")\n",
    "\n",
    "# Create a pipeline\n",
    "vw_pipeline = Pipeline(stages=[vw_featurizer, vw_model])"
   ]
  },
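  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for what hashing the feature names means, the next cell sketches the general hashing trick in plain Python. It is only an illustration under stated assumptions: `hash_features` is a hypothetical helper, `zlib.crc32` stands in for the hash function that VW actually uses, and the example row is made up. It does not reproduce the indices that `VowpalWabbitFeaturizer` generates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import zlib\n",
    "\n",
    "\n",
    "# Hypothetical helper, for illustration only: not part of SynapseML or VW\n",
    "def hash_features(row, num_bits=18):\n",
    "    mask = (1 << num_bits) - 1  # keep only the low num_bits bits of each hash\n",
    "    sparse = {}\n",
    "    for name, value in row.items():\n",
    "        if isinstance(value, str):\n",
    "            # Categorical feature: hash name and value together, weight 1.0\n",
    "            key, weight = f\"{name}={value}\", 1.0\n",
    "        else:\n",
    "            # Numeric feature: hash the name only and keep the value as the weight\n",
    "            key, weight = name, float(value)\n",
    "        index = zlib.crc32(key.encode(\"utf-8\")) & mask\n",
    "        # Colliding features simply add up, the usual hashing-trick trade-off\n",
    "        sparse[index] = sparse.get(index, 0.0) + weight\n",
    "    return sparse\n",
    "\n",
    "\n",
    "# Made-up example row using the three input columns selected above\n",
    "row = {\"education\": \"Bachelors\", \"marital-status\": \"Never-married\", \"hours-per-week\": 40}\n",
    "hash_features(row)"
   ]
  },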
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we are ready to train the model by fitting the pipeline with the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "vw_trained = vw_pipeline.fit(train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the model is trained, we apply it to predict the income of each sample in the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Making predictions\n",
    "test = test.withColumn(\"label\", when(col(\"income\").contains(\"<\"), 0.0).otherwise(1.0))\n",
    "prediction = vw_trained.transform(test)\n",
    "prediction.limit(10).toPandas()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we evaluate the model performance using `ComputeModelStatistics` function which will compute confusion matrix, accuracy, precision, recall, and AUC by default for classification models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.train import ComputeModelStatistics\n",
    "\n",
    "metrics = ComputeModelStatistics(\n",
    "    evaluationMetric=\"classification\", labelCol=\"label\", scoredLabelsCol=\"prediction\"\n",
    ").transform(prediction)\n",
    "metrics.toPandas()"
   ]
  }
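  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scalar metrics in the table above follow directly from the confusion matrix. As a minimal sketch in plain Python, the next cell computes accuracy, precision, and recall from hypothetical counts. The numbers are made up for illustration and are not results from this notebook. AUC is left out because it depends on the ranked scores rather than on the four counts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical confusion-matrix counts, for illustration only\n",
    "tp, fp, fn, tn = 40, 10, 20, 130\n",
    "\n",
    "accuracy = (tp + tn) / (tp + fp + fn + tn)  # share of all samples classified correctly\n",
    "precision = tp / (tp + fp)  # share of predicted positives that are truly positive\n",
    "recall = tp / (tp + fn)  # share of true positives that the model found\n",
    "\n",
    "print(f\"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f}\")"
   ]
  }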
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "name": "vw_mmlspark_adult_census",
  "notebookId": 453327986061569,
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
